---
title: "LLM with Streaming"
description: "Building an LLM with streaming output"
---


<Card
  title="View on GitHub"
  icon="github" href="https://github.com/basetenlabs/truss-examples/tree/main/03-llm-with-streaming">
</Card>

In this example, we walk through a Truss that serves an LLM and streams its output to the client.

# Why Streaming?

For some ML models, generation can take a long time. With LLMs in particular, a long output can take
10 to 20 seconds to generate. However, because LLMs generate tokens in sequence, useful output can be
made available to users before the full response is complete. Truss supports this through streaming
output. In this example, we build a Truss that streams the output of the Falcon-7B model.

# Set up the imports and key constants

In this example, we use the Hugging Face Transformers library to build a text generation model.

```python model/model.py
from threading import Thread
from typing import Dict

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextIteratorStreamer,
)

```
We use the instruct version of the Falcon-7B model, and define some defaults
for the inference parameters.

```python model/model.py
CHECKPOINT = "tiiuae/falcon-7b-instruct"
DEFAULT_MAX_NEW_TOKENS = 150
DEFAULT_TOP_P = 0.95


```
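
The constants above are fixed defaults. One possible extension, which is not part of the original example, is to let each request override them. The sketch below assumes hypothetical request keys `max_new_tokens` and `top_p` and a hypothetical helper named `resolve_generation_params`; it only illustrates the fallback logic, not the model call.

```python
from typing import Any, Dict

DEFAULT_MAX_NEW_TOKENS = 150
DEFAULT_TOP_P = 0.95


def resolve_generation_params(request: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: read optional overrides, falling back to the defaults."""
    max_new_tokens = int(request.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS))
    top_p = float(request.get("top_p", DEFAULT_TOP_P))
    # Reject values that would not make sense as sampling parameters.
    if max_new_tokens < 1:
        raise ValueError("max_new_tokens must be a positive integer")
    if not 0.0 < top_p <= 1.0:
        raise ValueError("top_p must be in the interval (0, 1]")
    return {"max_new_tokens": max_new_tokens, "top_p": top_p}


# Only top_p is overridden here; max_new_tokens falls back to its default.
params = resolve_generation_params({"prompt": "hello", "top_p": 0.8})
```

Validating the values up front keeps a malformed request from failing later, inside the generation thread.
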
# Define the load function

In the `load` function of the Truss, we implement the logic
for downloading the model and loading it into memory.

```python model/model.py
class Model:
    def __init__(self, **kwargs) -> None:
        self.tokenizer = None
        self.model = None

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
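```

We reuse the end-of-sequence token as the padding token so that the tokenizer
can pad inputs, and then load the model weights.

```python model/model.py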
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )

```
# Define the predict function

In the `predict` function of the Truss, we implement the actual
inference logic. The three main steps are:
* Tokenize the input
* Call the model's `generate` function, passing it a
`TextIteratorStreamer`. This is what gives us streaming output. We
also run `generate` in a `Thread`, so that it does not block the main
invocation.
* Return a generator that iterates over the `TextIteratorStreamer` object
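
The threading pattern in these steps can be illustrated without a model. The sketch below is a simplified stand-in, not the Transformers implementation: `QueueStreamer`, `fake_generate`, and `stream_words` are hypothetical names, and `fake_generate` merely pushes words into a queue where the real code would call `generate`.

```python
import queue
from threading import Thread
from typing import Iterator, List

_DONE = object()  # sentinel marking the end of the stream


class QueueStreamer:
    """Hypothetical stand-in for a text streamer: a thread-safe queue you can iterate."""

    def __init__(self) -> None:
        self._queue: "queue.Queue[object]" = queue.Queue()

    def put(self, text: str) -> None:
        self._queue.put(text)

    def end(self) -> None:
        self._queue.put(_DONE)

    def __iter__(self) -> Iterator[str]:
        while True:
            item = self._queue.get()  # blocks until the producer adds something
            if item is _DONE:
                return
            yield item


def fake_generate(words: List[str], streamer: QueueStreamer) -> None:
    """Placeholder for the model call: pushes one piece of text at a time."""
    for word in words:
        streamer.put(word + " ")
    streamer.end()


def stream_words(words: List[str]) -> Iterator[str]:
    streamer = QueueStreamer()
    thread = Thread(target=fake_generate, args=(words, streamer))
    thread.start()  # the producer runs in the background

    def inner() -> Iterator[str]:
        for text in streamer:
            yield text
        thread.join()  # the producer has finished once the sentinel is seen

    return inner()


chunks = list(stream_words(["streaming", "works"]))
```

The consumer blocks only while waiting for the next piece of text, so each piece can be forwarded as soon as the producer emits it.
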

```python model/model.py
    def predict(self, request: Dict) -> Dict:
        prompt = request.pop("prompt")
        inputs = self.tokenizer(
            prompt, return_tensors="pt", max_length=512, truncation=True, padding=True
        )
        input_ids = inputs["input_ids"].to("cuda")

```
Instantiate the Streamer object, which we'll later use for
returning the output to users.

```python model/model.py
        streamer = TextIteratorStreamer(self.tokenizer)
        generation_config = GenerationConfig(
            temperature=1,
            top_p=DEFAULT_TOP_P,
            top_k=40,
        )

```
When creating the generation parameters, make sure to pass the `streamer` object
that we created previously.

```python model/model.py
        with torch.no_grad():
            generation_kwargs = {
                "input_ids": input_ids,
                "generation_config": generation_config,
                "return_dict_in_generate": True,
                "output_scores": True,
                "pad_token_id": self.tokenizer.eos_token_id,
                "max_new_tokens": DEFAULT_MAX_NEW_TOKENS,
                "streamer": streamer,
            }

```
Spawn a thread to run the generation, so that it does not block the main
thread.

```python model/model.py
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

```
In Truss, streaming output is achieved by returning a generator
that yields content. In this example, we yield each piece of text from the
`streamer` as it is produced, until the generation is complete.

We define this `inner` function to create our generator. Once the stream is
exhausted, it joins the generation thread.

```python model/model.py
            def inner():
                for text in streamer:
                    yield text
                thread.join()

            return inner()
```
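
On the client side, a streamed response arrives as a sequence of byte chunks rather than a single body. The sketch below is an assumption about how a client might consume it and is not part of the Truss example: `iter_text` is a hypothetical helper that decodes chunks incrementally, so that a multi-byte UTF-8 character split across two chunks is not corrupted. The chunks here are hard-coded; with an HTTP client they would come from the response body (for instance, `Response.iter_content` in the `requests` library).

```python
import codecs
from typing import Iterable, Iterator


def iter_text(chunks: Iterable[bytes], encoding: str = "utf-8") -> Iterator[str]:
    """Hypothetical helper: yield decoded text as soon as each chunk arrives."""
    decoder = codecs.getincrementaldecoder(encoding)()
    for chunk in chunks:
        # Incomplete trailing bytes are buffered until the next chunk arrives.
        text = decoder.decode(chunk)
        if text:
            yield text
    # Flush the decoder; this raises if the stream ended mid-character.
    tail = decoder.decode(b"", final=True)
    if tail:
        yield tail


# Simulated stream: the two bytes of "é" are split across two chunks.
raw = "The café is open".encode("utf-8")
simulated_chunks = [raw[:8], raw[8:]]
pieces = list(iter_text(simulated_chunks))
```

Decoding each chunk independently with `bytes.decode` would fail on the split character, which is why the sketch keeps one decoder for the whole stream.
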

# Setting up the config.yaml

Running Falcon-7B requires `torch`, `transformers`,
and a few other related libraries.

```yaml config.yaml
model_name: "LLM with Streaming"
model_metadata:
    example_model_input: {"prompt": "what is the meaning of life"}
requirements:
- torch==2.0.1
- peft==0.4.0
- scipy==1.11.1
- sentencepiece==0.1.99
- accelerate==0.21.0
- bitsandbytes==0.41.1
- einops==0.6.1
- transformers==4.31.0
```
## Configure resources for Falcon

Note that we need an A10G to run this model.

```yaml config.yaml
resources:
  cpu: "3"
  memory: 14Gi
  use_gpu: true
  accelerator: A10G
```
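
A rough back-of-the-envelope check suggests why a 24 GB GPU such as the A10G is a reasonable fit. The sketch below assumes roughly 7 billion parameters (taken from the model name, not an exact count) stored in `bfloat16` at 2 bytes each. It counts weights only; activations and the attention cache need additional memory on top.

```python
APPROX_PARAMS = 7_000_000_000  # assumption: "7B" taken at face value
BYTES_PER_PARAM = 2            # bfloat16 stores each value in 16 bits
A10G_MEMORY_GB = 24            # GPU memory on an A10G

# Weights-only estimate, in decimal gigabytes.
weights_gb = APPROX_PARAMS * BYTES_PER_PARAM / 1e9
headroom_gb = A10G_MEMORY_GB - weights_gb

print(f"weights: ~{weights_gb:.1f} GB, headroom: ~{headroom_gb:.1f} GB")
```
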

<RequestExample>
```python model/model.py
from threading import Thread
from typing import Dict

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextIteratorStreamer,
)

CHECKPOINT = "tiiuae/falcon-7b-instruct"
DEFAULT_MAX_NEW_TOKENS = 150
DEFAULT_TOP_P = 0.95


class Model:
    def __init__(self, **kwargs) -> None:
        self.tokenizer = None
        self.model = None

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )

    def predict(self, request: Dict) -> Dict:
        prompt = request.pop("prompt")
        inputs = self.tokenizer(
            prompt, return_tensors="pt", max_length=512, truncation=True, padding=True
        )
        input_ids = inputs["input_ids"].to("cuda")

        streamer = TextIteratorStreamer(self.tokenizer)
        generation_config = GenerationConfig(
            temperature=1,
            top_p=DEFAULT_TOP_P,
            top_k=40,
        )

        with torch.no_grad():
            generation_kwargs = {
                "input_ids": input_ids,
                "generation_config": generation_config,
                "return_dict_in_generate": True,
                "output_scores": True,
                "pad_token_id": self.tokenizer.eos_token_id,
                "max_new_tokens": DEFAULT_MAX_NEW_TOKENS,
                "streamer": streamer,
            }

            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            def inner():
                for text in streamer:
                    yield text
                thread.join()

            return inner()
```
```yaml config.yaml
model_name: "LLM with Streaming"
model_metadata:
    example_model_input: {"prompt": "what is the meaning of life"}
requirements:
- torch==2.0.1
- peft==0.4.0
- scipy==1.11.1
- sentencepiece==0.1.99
- accelerate==0.21.0
- bitsandbytes==0.41.1
- einops==0.6.1
- transformers==4.31.0
resources:
  cpu: "3"
  memory: 14Gi
  use_gpu: true
  accelerator: A10G
```
</RequestExample>
